{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "import os\n",
    "os.chdir('E:\\GitHub\\QA-abstract-and-reasoning')\n",
    "sys.path.append('E:\\GitHub\\QA-abstract-and-reasoning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from pgn.layers import BahdanauAttention, Encoder, Pointer\n",
    "from pgn.batcher import batcher\n",
    "from utils.saveLoader import load_embedding_matrix\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.config import VOCAB_PAD\n",
    "from utils.config_gpu import config_gpu\n",
    "config_gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run utils/params.py\n",
    "# 产生输入数据\n",
    "vocab = Vocab(VOCAB_PAD)\n",
    "ds =iter(batcher(vocab, params))\n",
    "enc_data, dec_data = next(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 以下是调试好的Decoder单元\n",
    "![](https://img-blog.csdn.net/20180809142518309?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3Rob3JtYXMxOTk2/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, embedding_matrix, \n",
    "                 dec_units, batch_size, attention):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.batch_sz = batch_size\n",
    "        self.dec_units = dec_units\n",
    "\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, weights=[embedding_matrix],\n",
    "                                                   trainable=False)\n",
    "\n",
    "        self.cell = tf.keras.layers.GRUCell(units=self.dec_units,\n",
    "                                            recurrent_initializer='glorot_uniform')\n",
    "        \n",
    "        self.attention = attention\n",
    "        self.fc1 = tf.keras.layers.Dense(self.dec_units*2)\n",
    "        self.fc2 = tf.keras.layers.Dense(vocab_size)\n",
    "        \n",
    "    def call(self, dec_input,  # (batch_size, )\n",
    "             prev_dec_hidden,  # (batch_size, dec_units)\n",
    "             enc_output,  # (batch_size, enc_len, enc_units)\n",
    "             enc_pad_mask, # (batch_size, enc_len)\n",
    "             use_coverage=True,\n",
    "             prev_coverage=None):\n",
    "        # 得到词向量, output[2]\n",
    "        # dec_x (batch_size, embedding_dim)\n",
    "        dec_x = self.embedding(dec_input)\n",
    "        \n",
    "        # 应用GRU单元算出dec_hidden\n",
    "        # 注意cell 返回的state是一个列表，gru单元中为 [h] lstm [h, c]\n",
    "        # 所以这里用[dec_hidden] 取出来，这样dec_hidden就是tensor形式了\n",
    "        # dec_output (batch_size, dec_units)\n",
    "        # dec_hidden (batch_size, dec_units), output[1]\n",
    "        dec_output, [dec_hidden] = self.cell(dec_x, [prev_dec_hidden])\n",
    "        \n",
    "        # 计算注意力，得到上下文，注意力分布，coverage\n",
    "        # context_vector (batch_size, enc_units), output[0]\n",
    "        # attn (batch_size, enc_len), output[4]\n",
    "        # coverage (batch_size, enc_len, 1), output[5]\n",
    "        context_vector, attn, coverage = self.attention(dec_hidden,\n",
    "                                                        enc_output,\n",
    "                                                        enc_pad_mask,\n",
    "                                                        use_coverage,\n",
    "                                                        prev_coverage)\n",
    "\n",
    "        # 将上一循环的预测结果跟注意力权重值结合在一起作为本次的GRU网络输入\n",
    "        # dec_output (batch_size, enc_units + dec_units)\n",
    "        dec_output = tf.concat([dec_output, context_vector], axis=-1)\n",
    "\n",
    "        # 保持维度不变，其实我也不确定第一个全连接层的units该设置为多少\n",
    "        # pred (batch_size, enc_units + dec_units)\n",
    "        pred = self.fc1(dec_output)\n",
    "        \n",
    "        # pred (batch_size, vocab), output[3]\n",
    "        pred = self.fc2(pred)\n",
    "        \n",
    "        \"\"\"output\n",
    "        output[0]: context_vector (batch_size, dec_units)\n",
    "        output[1]: dec_hidden (batch_size, dec_units)\n",
    "        output[2]: dec_x (batch_size, embedding_dim)\n",
    "        output[3]: pred (batch_size, vocab_size)\n",
    "        output[4]: attn (batch_size, enc_len)\n",
    "        output[5]: coverage (batch_size, enc_len, 1)\n",
    "        \"\"\"\n",
    "        return context_vector, dec_hidden, dec_x, pred, attn, coverage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "P_{gen} = \\sigma(w_{h^*}^Th_t^*+w_s^Ts_t+w_x^Tx_t+b_{ptr})\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 调试PGN模型\n",
    "### PGN初始化参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_matrix = load_embedding_matrix()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# params\n",
    "vocab_size = params[\"vocab_size\"]\n",
    "embedding_dim = 300\n",
    "enc_units = dec_units = attn_units = 256\n",
    "batch_size = 64\n",
    "enc_len = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = Encoder(vocab_size,\n",
    "               embedding_dim,\n",
    "               embedding_matrix,\n",
    "               enc_units,\n",
    "               batch_size)\n",
    "attention = BahdanauAttention(attn_units)\n",
    "decoder = Decoder(vocab_size,\n",
    "               embedding_dim,\n",
    "               embedding_matrix,\n",
    "               enc_units,\n",
    "               batch_size)\n",
    "pointer = Pointer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PGN.call的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (64,200)\n",
    "enc_inp = enc_data[\"enc_input\"]\n",
    "dec_inp = dec_data[\"dec_input\"]\n",
    "enc_extended_inp = enc_data[\"extended_enc_input\"]\n",
    "batch_oov_len = enc_data[\"max_oov_len\"]\n",
    "enc_pad_mask = enc_data[\"enc_mask\"]\n",
    "use_coverage = True\n",
    "prev_coverage=None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 改为使用tf.TensorArray"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "是为了能使用`@tf.function`加速训练\n",
    "\n",
    "[参考1](https://tensorflow.google.cn/guide/function#batching)\n",
    "\n",
    "[参考2](https://tensorflow.google.cn/tutorials/customization/performance#gotchas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "# tf.TensorArray 代替 list\n",
    "# predictions = []\n",
    "# attentions = []\n",
    "# p_gens = []\n",
    "# coverages = []\n",
    "predictions = tf.TensorArray(tf.float32, size=dec_inp.shape[1])\n",
    "attentions = tf.TensorArray(tf.float32, size=dec_inp.shape[1])\n",
    "p_gens = tf.TensorArray(tf.float32, size=dec_inp.shape[1])\n",
    "coverages = tf.TensorArray(tf.float32, size=dec_inp.shape[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 修改部分：\n",
    "decoder内置了attention\n",
    "\n",
    "模拟循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_output, enc_hidden = encoder(enc_inp)\n",
    "dec_hidden = enc_hidden\n",
    "prev_coverage = None\n",
    "for t in tf.range(dec_inp.shape[1]):\n",
    "    context_vector, dec_hidden, \\\n",
    "    dec_x, pred, attn, prev_coverage = decoder(dec_inp[:, t],  # (batch_size, )\n",
    "                                        dec_hidden,  # (batch_size, dec_units)\n",
    "                                        enc_output,  # (batch_size, enc_len, enc_units)\n",
    "                                        enc_pad_mask, # (batch_size, enc_len)\n",
    "                                        use_coverage,\n",
    "                                        prev_coverage)\n",
    "    p_gen = pointer(context_vector, dec_hidden, dec_x)\n",
    "    \n",
    "    predictions.write(t, pred)\n",
    "    attentions.write(t, attn)\n",
    "    p_gens.write(t, p_gen)\n",
    "    coverages.write(t, prev_coverage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 32233]),\n",
       " TensorShape([64, 200]),\n",
       " TensorShape([64, 1]),\n",
       " TensorShape([64, 200, 1]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions.element_shape, \\\n",
    "attentions.element_shape, \\\n",
    "p_gens.element_shape, \\\n",
    "coverages.element_shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### for循环之后计算final_dists"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "P(w) = p_{gen}P_{vocab}(w)+(1-P_{gen})\\sum_{i:w_i=w}a_i^t\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgn.model import _calc_final_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([64, 39, 32242])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fd = _calc_final_dist(enc_extended_inp,\n",
    "                     tf.transpose(predictions.stack(), perm=[1, 0, 2]), \n",
    "                      tf.transpose(attentions.stack(), perm=[1, 0, 2]), \n",
    "                      tf.transpose(p_gens.stack(), perm=[1, 0, 2]), \n",
    "                      batch_oov_len, \n",
    "                      vocab_size,\n",
    "                      batch_size)\n",
    "fd.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 调试完毕,调包测试代码\n",
    "修改pgn.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "# from pgn.layers import Encoder, Decoder, BahdanauAttention, Pointer\n",
    "import tensorflow as tf\n",
    "from pgn.batcher import batcher\n",
    "from pgn.model import PGN\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.config import VOCAB_PAD\n",
    "from utils.config_gpu import config_gpu\n",
    "from pgn.model import _calc_final_dist\n",
    "config_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run utils/params.py\n",
    "# 产生输入数据\n",
    "vocab = Vocab(VOCAB_PAD)\n",
    "ds =iter(batcher(vocab, params))\n",
    "enc_data, dec_data = next(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = PGN(params)\n",
    "enc_inp = enc_data[\"enc_input\"]\n",
    "dec_inp = dec_data[\"dec_input\"]\n",
    "enc_extended_inp = enc_data[\"extended_enc_input\"]\n",
    "batch_oov_len = enc_data[\"max_oov_len\"]\n",
    "enc_pad_mask = enc_data[\"enc_mask\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调用model.call()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_dist,attentions, coverages = model(enc_inp, dec_inp, enc_extended_inp, \n",
    "                       batch_oov_len, enc_pad_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 39, 32237]),\n",
       " TensorShape([64, 39, 200]),\n",
       " TensorShape([64, 39, 200]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_dist.shape, attentions.shape, coverages.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 测试损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgn.loss import loss_function, loss_function2, _coverage_loss, calc_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = dec_data[\"dec_target\"]\n",
    "padding_mask = dec_data[\"dec_mask\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "log loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_function(target, final_dist, padding_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "cov_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$t$是decoder的时间步 (0, dec_len)\n",
    "\n",
    "$i$是encoder的时间步 (0, enc_len)\n",
    "\n",
    "$$\n",
    "covloss_t = \\sum_i min(a_i^t, c_i^t)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=8017, shape=(), dtype=float32, numpy=0.6298077>"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_coverage_loss(attentions, coverages, padding_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: id=8062, shape=(), dtype=float32, numpy=6.849601>,\n",
       " <tf.Tensor: id=8051, shape=(), dtype=float32, numpy=6.534697>,\n",
       " <tf.Tensor: id=8059, shape=(), dtype=float32, numpy=0.6298077>)"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calc_loss(target, final_dist, padding_mask, attentions, coverages, cov_loss_wt=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调用Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_output, enc_hidden = model.encoder(enc_inp)\n",
    "dec_hidden = enc_hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调用Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# enc_output 由Encoder获得\n",
    "# prev_dec_hidden 由Encoder获得\n",
    "# enc_pad_mask 前面有了\n",
    "dec_input = tf.constant([vocab.word2id[vocab.START_DECODING]] * 64)\n",
    "# dec_input = tf.expand_dims([vocab.word2id[vocab.START_DECODING]] * 64, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "context_vector, dec_hidden, \\\n",
    "dec_x, pred, attn, coverage = model.decoder(dec_input, dec_hidden, \n",
    "                                            enc_output, enc_pad_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32233, 64)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_size = params[\"vocab_size\"]\n",
    "batch_size = params[\"batch_size\"]\n",
    "vocab_size,batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 计算单步decoder的final_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_gen = model.pointer(context_vector, dec_hidden, dec_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([64, 1, 32237])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 保证pred attn p_gen的参数为3D的\n",
    "final_dist = _calc_final_dist(enc_extended_inp,\n",
    "                     tf.expand_dims(pred, 1), \n",
    "                     tf.expand_dims(attn, 1), \n",
    "                     tf.expand_dims(p_gen, 1), \n",
    "                      batch_oov_len, \n",
    "                      vocab_size,\n",
    "                      batch_size)\n",
    "final_dist.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 单步decoder封装成函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造第一个dec的输入\n",
    "dec_input = tf.constant([vocab.word2id[vocab.START_DECODING]] * 64)\n",
    "\n",
    "# 事先计算好enc的输出\n",
    "enc_output, enc_hidden = model.encoder(enc_inp)\n",
    "dec_hidden = enc_hidden\n",
    "\n",
    "\n",
    "def decode_one_step(dec_input, dec_hidden, enc_output,\n",
    "                   enc_pad_mask, prev_coverage, use_coverage=True):\n",
    "    # 开始decoder\n",
    "    context_vector, dec_hidden, \\\n",
    "    dec_x, pred, attn, coverage = model.decoder(dec_input, dec_hidden, enc_output,\n",
    "                                  enc_pad_mask, prev_coverage, use_coverage)\n",
    "    \n",
    "    # 计算p_gen\n",
    "    p_gen = model.pointer(context_vector, dec_hidden, dec_x)\n",
    "    \n",
    "    # 保证pred attn p_gen的参数为3D的\n",
    "    final_dist = _calc_final_dist(enc_extended_inp,\n",
    "                         tf.expand_dims(pred, 1), \n",
    "                         tf.expand_dims(attn, 1), \n",
    "                         tf.expand_dims(p_gen, 1), \n",
    "                          batch_oov_len, \n",
    "                          params[\"vocab_size\"],\n",
    "                          params[\"batch_size\"])\n",
    "    \n",
    "    return final_dist, dec_hidden, coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_dist, dec_hidden, coverage = decode_one_step(dec_input, dec_hidden, enc_output,\n",
    "                                       enc_pad_mask, use_coverage=True, prev_coverage=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "384.4px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
